#include "SSH.h"

namespace Upp {

// File-local logging helper: every LLOG message in this file is forwarded to RLOG.
#define LLOG(x)	 RLOG(x)

Ssh& Ssh::StartConnect(const String& host, int port)
{
	// Queues the whole client connection sequence as a chain of jobs and returns *this:
	// init -> DNS lookup -> TCP connect + libssh2 session init -> transport method
	// preferences -> handshake -> host key fingerprint/verification -> auth method
	// list -> user authentication.
	// Every job below follows the same convention: return true while its step is still
	// pending (lookup in progress, LIBSSH2_ERROR_EAGAIN, ...) and false once it is done.
	// Failures are reported through Error(), which forwards to Halt().

	// Init
	AddJob() << [=] {
		if(host.IsEmpty())
			Error(-1, t_("Host is not specified"));
		else
		if(username.IsEmpty() || password.IsEmpty())
			Error(-1, t_("Username or password is not specified."));
		socket.Clear();
		socket.WhenWait = Proxy(WhenDo);
		session = NULL;
		connected = false;
		ip_addrinfo.Start(host, port);
		LLOG(Format("** SSH2: Starting DNS sequence locally for %s:%d", host, port));
		return false;
	};

	// DNS lookup & socket creation.
	AddJob() << [=] {
		if(ip_addrinfo.InProgress())
			return true;
		if(!ip_addrinfo.GetResult())
			Error(-1, Format(t_("DNS query for '%s' failed."), host));
		return false;
	};

	// Connect to SSH server and init libssh2 session (switched to non-blocking mode).
	AddJob() << [=] {
		if(socket.Connect(ip_addrinfo) && socket.WaitConnect()) {
			LLOG(Format("++ SSH2: Successfully connected to %s:%d", host, port));
			ip_addrinfo.Clear();
			// Maybe we should make the memory managers (system/upp) switchable on runtime?
#ifdef flagUSEMALLOC
			LLOG("** SSH2: Using libssh2's own memory manager.");
			session = libssh2_session_init(NULL, NULL, NULL, this);
#else
			LLOG("** SSH2: Using U++ style memory managers.");
			session = libssh2_session_init_ex((*ssh_malloc), (*ssh_free), (*ssh_realloc), this);
#endif
			if(!session)
				Error(-1, t_("Unable to initalize libssh2 session."));
			libssh2_session_set_blocking(session, 0);
			WhenConfig(*this);
			return false;
		}
		else
		if(socket.IsError())
			Error(-1, t_("Couldn't connect to ") << host);
		return true;
	};

	// Set up transport methods: apply the preferences stored in ssh_methods one method
	// type per run, removing each entry once libssh2 has accepted it.
	AddJob() << [=] {
		if(ssh_methods.IsEmpty())
			return false;
		auto rc = libssh2_session_method_pref(session, ssh_methods.GetKey(0), ~GetMethodNames(0));
		if(rc == 0)	{
			ssh_methods.Remove(0);
			return ssh_methods.GetCount() > 0;
		}
		if(rc != LIBSSH2_ERROR_EAGAIN) Error();
		return true;
	};

	// Start handshake
	AddJob() << [=] {
		auto rc = libssh2_session_handshake(session, socket.GetSOCKET());
		if(rc == 0) {
			LLOG("++ SSH2: Handshake succesfull.");
			connected = true;
			return false;
		}
		if(rc != LIBSSH2_ERROR_EAGAIN) Error();
		return true;
	};

	// Get host key hash (fingerprint) and let the user verify it.
	AddJob() << [=] {
		// libssh2_hostkey_hash() returns the raw binary SHA-1 digest (20 bytes, not
		// NUL-terminated, may contain zero bytes) or NULL when it is not available.
		// Copy it by explicit length: treating it as a C string would truncate it at
		// the first zero byte and read past the end of the digest.
		const int hash_len = 20;
		const char *hash = libssh2_hostkey_hash(session, LIBSSH2_HOSTKEY_HASH_SHA1);
		fingerprint = hash ? String(hash, hash_len) : String();
		if(fingerprint.IsEmpty())
			LLOG("!! SSH2: Fingerprint is not available.");
		DUMPHEX(fingerprint);
		// A rejection by WhenVerify is not a libssh2 error, so report it explicitly;
		// a bare Error() would pick up whatever libssh2 last recorded (possibly a stale
		// EAGAIN from the handshake, or no error at all).
		if(WhenVerify && !WhenVerify(*this))
			Error(-1, t_("Host key verification failed."));
		return false;
	};

	// Retrieve authorization methods
	AddJob() << [=] {
		auth_methods = libssh2_userauth_list(session, username, username.GetLength());
		if(!auth_methods.IsEmpty()) {
			LLOG("++ SSH2: Authentication methods successfully retrieved.");
			CheckAuthMethods();
			WhenAuth(*this);
			return false;
		}
		if(!WouldBlock()) Error();
		return true;
	};

	// Authenticate user with the selected auth_method.
	AddJob() << [=] {
		int rc = -1;
		switch(auth_method) {
			case PASSWORD:  rc = DoPwdAuth(); break;
			case PUBLICKEY: rc = DoKeyAuth(); break;
			case KEYBOARD:  rc = DoKbdAuth(); break;
			default: NEVER();
		}
		if(rc == 0 && libssh2_userauth_authenticated(session)) {
			LLOG("++ SSH2: Client succesfully authenticated and connected.");
			return false;
		}
		if(rc != LIBSSH2_ERROR_EAGAIN) Error();
		return true;
	};
	return *this;
}

Ssh& Ssh::StartDisconnect()
{
	// Queues two jobs and returns *this for chaining:
	//   1) a graceful SSH disconnect, only if a handshake completed (connected == true);
	//   2) releasing the libssh2 session handle.
	// Both return true while libssh2 reports LIBSSH2_ERROR_EAGAIN and false once done.

	// 1) Send the disconnect message. Failures other than EAGAIN are not reported here;
	//    the job simply finishes.
	AddJob() << [=] {
		if(!session || !connected)
			return false;
		int rc = libssh2_session_disconnect(session, "Disconnecting...");
		if(rc != 0)
			return rc == LIBSSH2_ERROR_EAGAIN;
		connected = false;
		return false;
	};

	// 2) Free the session handle; a hard failure is reported through Error().
	AddJob() << [=] {
		if(session == NULL) {
			LLOG("!! SSH2: No session handle found. Couldn't release resources.");
			return false;
		}
		int rc = libssh2_session_free(session);
		if(rc == LIBSSH2_ERROR_EAGAIN)
			return true;
		if(rc != 0) {
			Error();
			return true;
		}
		LLOG("++ SSH2: Client disconnected, and resources released.");
		session = NULL;
		return false;
	};
	return *this;
}

String Ssh::GetMethodNames(int type)
{
	// Returns the method names stored in ssh_methods[type] as one comma-separated
	// string (e.g. "a,b,c"), the form handed to libssh2_session_method_pref().
	const auto& names = ssh_methods[type];
	String list;
	for(int i = 0; i < names.GetCount(); i++) {
		if(i > 0)
			list << ",";
		list << names[i].To<String>();
	}
	return list;
}

ValueMap Ssh::GetMethods()
{
	// Collects, per method type, the algorithm names libssh2 supports for this session.
	// Returns an empty map when there is no session yet. Method types for which libssh2
	// reports no algorithms (or an error) are skipped.
	ValueMap methods;
	if(session) {
		// The range is inclusive: METHOD_SCOMPRESSION is itself a queryable method type,
		// and with `<` it was never reported.
		// NOTE(review): assumes the METHOD_* enum mirrors LIBSSH2_METHOD_* with
		// METHOD_SCOMPRESSION == LIBSSH2_METHOD_COMP_SC (the last type
		// libssh2_session_supported_algs accepts) - confirm against SSH.h.
		for(int i = METHOD_EXCHANGE; i <= METHOD_SCOMPRESSION; i++) {
			const char **p = NULL;
			auto rc = libssh2_session_supported_algs(session, i, &p);
			if(rc > 0) {
				auto& v = methods(i);
				for(int j = 0; j < rc; j++) {
					v << p[j];
				}
				// The name array is allocated by libssh2 and must be released with libssh2_free().
				libssh2_free(session, p);
			}
		}
	}
	return pick(methods);
}

void Ssh::CheckAuthMethods()
{
	// Succeeds silently if the server offers at least one method this client can drive
	// (see the auth_method switch in StartConnect); otherwise reports an error.
	static const char *known[] = { "password", "keyboard-interactive", "publickey" };
	auto offered = GetAuthMethods();
	for(auto& name : offered) {
		for(const char *k : known) {
			if(name.IsEqual(k))
				return;
		}
	}
	Error(-1, t_("No valid authentication method found."));
}

int Ssh::DoPwdAuth()
{
	// Plain password authentication for `username`.
	// Returns libssh2's result code (0 on success, LIBSSH2_ERROR_EAGAIN while pending).
	return libssh2_userauth_password(session, ~username, ~password);
}

int Ssh::DoKeyAuth()
{
	return libssh2_userauth_publickey_fromfile(session, username, public_key, private_key, password);
}

int Ssh::DoKbdAuth()
{
	// Keyboard-interactive authentication; server prompts are answered by
	// ssh_keyboard_callback. Returns libssh2's result code.
	return libssh2_userauth_keyboard_interactive(session, username, ssh_keyboard_callback);
}

void Ssh::CleanUp()
{
	// Queues the disconnect jobs when a session exists. The queue is then executed,
	// unless IsCleanup() is set (presumably: clean-up already underway), in which case
	// only a log line is written.
	if(session)
		StartDisconnect();
	if(IsCleanup()) {
		LLOG("** SSH2: Performing clean up...");
		return;
	}
	Execute();
}

void Ssh::Error(int code, const char* msg)
{
	Tuple<int, String> t = ssh_liberror(session, code, msg);
	Halt(t.Get<int>(), t.Get<String>());
}

Ssh::Ssh()
{
	// Start with no libssh2 session and not connected; password authentication is the
	// default. GlobalTimeout(60000) - presumably milliseconds, i.e. 60 s.
	session   = NULL;
	connected = false;
	GlobalTimeout(60000);
	auth_method = PASSWORD;
	// Whenever the job queue halts, release the session via CleanUp().
	WhenHalt = [this] { CleanUp(); };
}

Ssh::Subsystem::Subsystem()
{
	// Not yet bound to an Ssh object; Session() sets `ssh` later.
	ssh = NULL;
	type = Type::UNDEFINED;
	chunk_size = 65536;
	packet_length = int64(0);
	// Whenever this subsystem's job queue halts, run its clean-up.
	WhenHalt = [this] { CleanUp(); };
}

void Ssh::Subsystem::Session(Ssh& session)
{
	// Binds this subsystem to an existing Ssh object. Our WhenDo forwards to the
	// session's WhenDo (looked up through `ssh` at call time). StartInit() is then invoked.
	ssh = &session;
	WhenDo = [this] {
		ssh->WhenDo();
	};
	StartInit();
}

void Ssh::Subsystem::Clear()
{
	// Drops queued jobs. If the head job is StartInit (as identified by
	// GetJob(0) == &Subsystem::StartInit), it is kept and only the jobs behind it are
	// removed; otherwise the whole queue is cleared.
	if(QueueIsEmpty())
		return;
	if(GetJob(0) == &Subsystem::StartInit) {
		while(GetJobCount() > 1)
			RemoveJob(1);
		return;
	}
	ClearQueue();
}

void Ssh::Subsystem::Error(int code, const char* msg)
{
	ASSERT(ssh);
	auto t = ssh_liberror(ssh->GetSession(), code, msg);
	Halt(t.Get<int>(), t.Get<String>());
}

Tuple<int, String> ssh_liberror(SshSession* session, int code, const char* msg)
{
	// Builds an (error code, message) pair and logs it.
	// With a non-NULL session and code == 0, libssh2's last recorded error is used;
	// otherwise the given code and msg are returned as-is.
	Tuple<int, String> t;
	if(session && code == 0) {
		// With want_buf == 0 libssh2 does not copy anything: it only stores a pointer to
		// its session-owned message (and its length). A single char* is therefore all we
		// need; the text is copied into a String immediately.
		char *libmsg = NULL;
		int libmsg_len = 0;
		int rc = libssh2_session_last_error(session, &libmsg, &libmsg_len, 0);
		t = MakeTuple<int, String>(rc, libmsg ? String(libmsg, libmsg_len) : String());
	}
	else
		t = MakeTuple<int, String>(code, msg);
	LLOG(Format("-- SSH2: Error (code: %d): %s", t.Get<int>(), t.Get<String>()));
	return pick(t);
}

// Global libssh2 setup and teardown for this package.
INITIALIZER(SSH) {
	LLOG("Initializing libssh2...");
	// libssh2_init() returns 0 on success. Report a failure instead of silently
	// ignoring it, since every later libssh2 call depends on it.
	int rc = libssh2_init(0);
	if(rc != 0)
		LLOG(Format("-- SSH2: libssh2 initialization failed (code: %d).", rc));
}
EXITBLOCK {
	LLOG("Deinitializing libssh2...");
	libssh2_exit();
}
}